import numpy as np
import scipy.stats as spst
from numpy.testing import assert_approx_equal, assert_array_almost_equal_nulp

import openpnm as op
import openpnm.models.misc as mods


class MiscTest:
    """Tests for the models in ``openpnm.models.misc``.

    Most tests share the single network built in ``setup_class`` and add
    models/properties to it in turn, so some tests rely on state (e.g.
    'throat.seed') left behind by earlier ones and are sensitive to the
    order in which they run.
    """

    def setup_class(self):
        # Shared 5x5x5 cubic network reused by the stateful tests below.
        self.net = op.network.Cubic(shape=[5, 5, 5])

    # ------------------------------------------------------------------
    # Private helpers (not collected as tests: no ``test`` prefix)
    # ------------------------------------------------------------------

    def _pore_seed_from_throat_seed(self, mode):
        """Reset 'pore.seed' and recompute it from evenly spaced throat values.

        'throat.seed' is set to ``linspace(0, 1, Nt)`` so the reductions
        checked by the calling tests have fixed, reproducible results.
        """
        self.net.pop('pore.seed', None)
        self.net.models.pop('pore.seed', None)
        self.net.models.pop('throat.seed', None)
        self.net['throat.seed'] = np.linspace(0, 1, self.net.Nt)
        self.net.add_model(model=mods.from_neighbor_throats,
                           propname='pore.seed',
                           prop='throat.seed',
                           mode=mode)

    def _throat_seed_from_pore_seed(self, mode):
        """Recompute 'throat.seed' from fresh random pore values using ``mode``.

        Returns ``pore.seed[throat.conns]``, i.e. the pore values at the two
        ends of every throat (presumably shape (Nt, 2)), so the caller can
        compute the expected reduction directly.
        """
        # NOTE(review): the two `del`s assume 'throat.seed' was created by
        # earlier tests in this class; running these tests in isolation may
        # fail here — confirm.
        del self.net['throat.seed']
        del self.net.models['throat.seed']
        self.net['pore.seed'] = np.random.rand(self.net.Np)
        self.net.add_model(model=mods.from_neighbor_pores,
                           propname='throat.seed',
                           prop='pore.seed',
                           mode=mode)
        return self.net['pore.seed'][self.net['throat.conns']]

    @staticmethod
    def _throat_values_with_nan():
        """Return a fresh 2x2x2 cubic network whose first 'throat.values' is NaN."""
        net = op.network.Cubic(shape=[2, 2, 2])
        net['throat.values'] = np.linspace(0, 1, net.Nt)
        net['throat.values'][0] = np.nan
        return net

    # ------------------------------------------------------------------
    # Simple arithmetic models
    # ------------------------------------------------------------------

    def test_constant(self):
        """``constant`` fills every pore with the given value."""
        self.net.add_model(model=mods.constant,
                           propname='pore.value',
                           value=3.3)
        assert np.all(np.unique(self.net['pore.value']) == 3.3)

    def test_product(self):
        """``product`` multiplies two or more properties element-wise."""
        self.net.add_model(model=mods.constant,
                           propname='pore.value1',
                           value=2)
        self.net.add_model(model=mods.constant,
                           propname='pore.value2',
                           value=2)
        self.net.add_model(model=mods.product,
                           propname='pore.result1',
                           props=['pore.value1', 'pore.value2'])
        assert np.all(np.unique(self.net['pore.result1']) == 4)
        self.net.add_model(model=mods.constant,
                           propname='pore.value3',
                           value=2)
        self.net.add_model(model=mods.product,
                           propname='pore.result2',
                           props=['pore.value1', 'pore.value2', 'pore.value3'])
        assert np.all(np.unique(self.net['pore.result2']) == 8)

    def test_generic_function(self):
        """``generic_function`` forwards the property and kwargs to ``func``."""
        self.net['pore.rand'] = np.random.rand(self.net.Np)
        self.net.add_model(model=mods.generic_function,
                           func=np.clip,
                           propname='pore.clipped',
                           prop='pore.rand',
                           a_min=0.2, a_max=0.8)
        # With Np random values in [0, 1) both clip bounds are (almost
        # surely) hit, so the extremes equal the bounds exactly.
        assert np.amax(self.net['pore.clipped']) == 0.8
        assert np.amin(self.net['pore.clipped']) == 0.2

    def test_scaled(self):
        """``scaled`` multiplies a property by ``factor``."""
        self.net['pore.value4'] = 4
        self.net.add_model(model=mods.scaled,
                           propname='pore.value5',
                           prop='pore.value4',
                           factor=2)
        assert np.all(np.unique(self.net['pore.value5']) == 8)

    def test_linear(self):
        """``linear`` computes ``m * x + b``: 2 * 4 + 2 == 10."""
        self.net['pore.value4'] = 4
        self.net.add_model(model=mods.linear,
                           propname='pore.value6',
                           prop='pore.value4',
                           m=2, b=2)
        assert np.all(np.unique(self.net['pore.value6']) == 10)

    def test_polynomial(self):
        """``polynomial`` evaluates 0 + 2*4 + 4*4**2 + 6*4**3 == 456."""
        self.net['pore.value4'] = 4
        self.net.add_model(model=mods.polynomial,
                           propname='pore.value7',
                           prop='pore.value4',
                           a=[0, 2, 4, 6])
        assert np.all(np.unique(self.net['pore.value7']) == 456)

    # ------------------------------------------------------------------
    # Random values
    # ------------------------------------------------------------------

    def test_random_no_seed(self):
        """Without a seed, regenerating the model produces new values."""
        self.net.add_model(model=mods.random,
                           propname='pore.seed',
                           element='pore',
                           seed=None)
        temp1 = self.net['pore.seed'].copy()
        self.net.regenerate_models()
        temp2 = self.net['pore.seed'].copy()
        assert np.all(temp1 != temp2)

    def test_random_with_seed(self):
        """With a fixed seed, regenerating the model reproduces the values."""
        self.net.add_model(model=mods.random,
                           propname='pore.seed',
                           element='pore',
                           seed=0)
        temp1 = self.net['pore.seed'].copy()
        self.net.regenerate_models()
        temp2 = self.net['pore.seed'].copy()
        assert_array_almost_equal_nulp(temp1, temp2)

    def test_random_with_range(self):
        """Random values stay within ``num_range``."""
        self.net.add_model(model=mods.random,
                           propname='pore.seed',
                           element='pore',
                           num_range=[0.1, 0.9])
        self.net.regenerate_models()
        assert np.amax(self.net['pore.seed']) <= 0.9
        assert np.amin(self.net['pore.seed']) >= 0.1

    # ------------------------------------------------------------------
    # Pore values from neighboring throats (shared network)
    # ------------------------------------------------------------------

    def test_from_neighbor_throats_min(self):
        """mode='min' picks the smallest neighboring throat value."""
        self._pore_seed_from_throat_seed(mode='min')
        # Every pore value must be one of the throat values it came from.
        # (np.isin replaces np.in1d, which is deprecated in NumPy 2.0.)
        assert np.all(np.isin(self.net['pore.seed'], self.net['throat.seed']))
        assert np.isclose(self.net['throat.seed'].mean(), 0.5)
        assert np.isclose(self.net['pore.seed'].mean(), 0.16454849498327762)

    def test_from_neighbor_throats_max(self):
        """mode='max' picks the largest neighboring throat value."""
        self._pore_seed_from_throat_seed(mode='max')
        assert np.all(np.isin(self.net['pore.seed'], self.net['throat.seed']))
        assert np.isclose(self.net['throat.seed'].mean(), 0.5)
        assert np.isclose(self.net['pore.seed'].mean(), 0.8595317725752508)

    def test_from_neighbor_throats_mean(self):
        """mode='mean' averages neighboring throat values."""
        self._pore_seed_from_throat_seed(mode='mean')
        assert np.isclose(self.net['throat.seed'].mean(), 0.5)
        assert np.isclose(self.net['pore.seed'].mean(), 0.5)

    # ------------------------------------------------------------------
    # NaN handling (each test builds its own small network)
    # ------------------------------------------------------------------

    def test_neighbor_pores_with_nans(self):
        """NaN pore values propagate unless ``ignore_nans=True``, in every mode."""
        net = op.network.Cubic(shape=[2, 2, 2])
        net['pore.values'] = 1.0
        net['pore.values'][0] = np.nan
        f = mods.from_neighbor_pores
        for mode in ('min', 'max', 'mean'):
            with_nans = f(net, prop='pore.values',
                          ignore_nans=False, mode=mode)
            assert np.any(np.isnan(with_nans))
            no_nans = f(net, prop='pore.values',
                        ignore_nans=True, mode=mode)
            assert np.all(~np.isnan(no_nans))

    def test_neighbor_throats_mode_min_with_nans(self):
        """mode='min' with a NaN throat: NaN propagates unless ignored."""
        net = self._throat_values_with_nan()
        f = mods.from_neighbor_throats
        with_nans = f(net, prop='throat.values',
                      ignore_nans=False, mode='min')
        assert np.any(np.isnan(with_nans))
        no_nans = f(net, prop='throat.values',
                    ignore_nans=True, mode='min')
        # isfinite rejects both NaN and +/-inf in one check.
        assert np.all(np.isfinite(no_nans))
        assert np.allclose(no_nans, np.array([0.36363636, 0.45454545,
                                              0.09090909, 0.09090909,
                                              0.18181818, 0.18181818,
                                              0.27272727, 0.27272727]))

    def test_neighbor_throats_mode_max_with_nans(self):
        """mode='max' with a NaN throat: NaN propagates unless ignored."""
        net = self._throat_values_with_nan()
        f = mods.from_neighbor_throats
        with_nans = f(net, prop='throat.values',
                      ignore_nans=False, mode='max')
        assert np.any(np.isnan(with_nans))
        no_nans = f(net, prop='throat.values',
                    ignore_nans=True, mode='max')
        assert np.all(np.isfinite(no_nans))
        assert np.allclose(no_nans, np.array([0.72727273, 0.81818182,
                                              0.90909091, 1.00000000,
                                              0.72727273, 0.81818182,
                                              0.90909091, 1.00000000]))

    def test_neighbor_throats_mode_mean_with_nans(self):
        """mode='mean' with a NaN throat: NaN propagates unless ignored."""
        net = self._throat_values_with_nan()
        f = mods.from_neighbor_throats
        with_nans = f(net, prop='throat.values',
                      ignore_nans=False, mode='mean')
        assert np.any(np.isnan(with_nans))
        no_nans = f(net, prop='throat.values',
                    ignore_nans=True, mode='mean')
        assert np.all(np.isfinite(no_nans))
        assert np.allclose(no_nans, np.array([0.54545455, 0.63636364,
                                              0.45454545, 0.51515152,
                                              0.48484848, 0.54545455,
                                              0.57575758, 0.63636364]))

    # ------------------------------------------------------------------
    # Throat values from neighboring pores (shared network)
    # ------------------------------------------------------------------

    def test_from_neighbor_pores_min(self):
        """mode='min' matches a direct min over each throat's two pores."""
        end_values = self._throat_seed_from_pore_seed(mode='min')
        tseed = np.amin(end_values, axis=1)
        assert_array_almost_equal_nulp(self.net['throat.seed'], tseed)

    def test_from_neighbor_pores_max(self):
        """mode='max' matches a direct max over each throat's two pores."""
        end_values = self._throat_seed_from_pore_seed(mode='max')
        tseed = np.amax(end_values, axis=1)
        assert_array_almost_equal_nulp(self.net['throat.seed'], tseed)

    def test_from_neighbor_pores_mean(self):
        """mode='mean' matches a direct mean over each throat's two pores."""
        end_values = self._throat_seed_from_pore_seed(mode='mean')
        tseed = np.mean(end_values, axis=1)
        assert_array_almost_equal_nulp(self.net['throat.seed'], tseed)

    # ------------------------------------------------------------------
    # Miscellaneous models (each test builds its own network)
    # ------------------------------------------------------------------

    def test_invert(self):
        """``invert`` returns the reciprocal: 1 / 2.0 == 0.5."""
        net = op.network.Cubic(shape=[5, 5, 5])
        net['pore.diameter'] = 2.0
        net.add_model(propname='pore.entry_pressure',
                      prop='pore.diameter',
                      model=mods.invert)
        assert net['pore.entry_pressure'][0] == 0.5

    def test_match_histograms(self):
        """Only bin centers are produced, and counts follow bin heights."""
        net = op.network.Cubic(shape=[5, 5, 5])
        c = [0.1, 0.3, 0.8, 1.2]
        h = [5, 20, 20, 100]
        a = mods.match_histogram(net, bin_centers=c, bin_heights=h,
                                 element='pore')
        assert np.all(np.unique(a) == c)
        _, nums = np.unique(a, return_counts=True)
        # Tallest bin (100) gets the most pores; shortest (5) the fewest.
        assert nums[3] == np.amax(nums)
        assert nums[0] == np.amin(nums)

    def test_generic_distribution(self):
        """Frozen and unfrozen scipy distributions give identical values."""
        pn = op.network.Cubic(shape=[5, 5, 5])
        pn['pore.seed'] = np.random.rand(pn.Np)
        pn.add_model(propname='pore.test1',
                     model=op.models.misc.generic_distribution,
                     seeds='pore.seed',
                     func=spst.weibull_min,
                     c=2.8,
                     scale=1e-5)
        pn.add_model(propname='pore.test2',
                     model=op.models.misc.generic_distribution,
                     seeds='pore.seed',
                     func=spst.weibull_min(c=2.8, scale=1e-5))
        assert np.all(pn['pore.test1'] == pn['pore.test2'])
        # Changing a shape parameter on the unfrozen model must change it.
        pn.models['pore.test1@all']['c'] = 1.0
        pn.regenerate_models()
        assert np.all(pn['pore.test1'] != pn['pore.test2'])


if __name__ == '__main__':
    # Run every test method in sequence without pytest.
    t = MiscTest()
    self = t  # unused alias; presumably kept so `self.net` works interactively
    t.setup_class()
    # t.__dir__() is deliberately used instead of dir(t): dir() sorts names
    # alphabetically, while object.__dir__ does not. The tests share state
    # on t.net, so their relative order matters.
    test_names = [name for name in t.__dir__() if name.startswith('test')]
    for name in test_names:
        print(f"Running test: {name}")
        getattr(t, name)()
